import numpy as np
import matplotlib.pyplot as plot

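# The 13 tensor operations.
'''
Every operation below takes the same four arguments:

param expr:      the call_tensor.Expr object being evaluated
param op:        the call_tensor.Expr.OP object attached to expr
param args:      values already computed for the inputs of expr; data loaded
                 by Input or Input2d reaches the other operations through args
param **kwargs:  named operands supplied by the caller as keyword arguments;
                 Input and Input2d look their data up here by op.name

'''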
# Fetch a caller-supplied operand and convert it to floating point.
def Input(expr, op, args, **kwargs):
    """Return the operand stored under ``op.name`` in ``kwargs`` as float data.

    An ``int`` or ``float`` becomes a Python ``float``; any object with a
    ``shape`` attribute is converted with ``astype(float)``.

    Raises:
        Exception: if ``op.name`` is not a key of ``kwargs``, or if the value
            is neither an int/float nor an object with a ``shape`` attribute.
    """
    if op.name not in kwargs:
        raise Exception("%s: missing input" % expr)

    operand = kwargs[op.name]
    if isinstance(operand, (int, float)):
        return float(operand)
    if hasattr(operand, "shape"):
        return operand.astype(float)
    raise Exception("%s: Input must be float or int or ndarray: %s" % (expr, operand))

def Input2d(expr, op, args, **kwargs):
    """Return the image batch stored under ``op.name`` as a float NCHW array.

    The batch must be 4-D with axes 1, 2 and 3 equal to
    ``op.parameters["height"]``, ``["width"]`` and ``["in_channels"]``
    (NHWC layout). The result is the float copy transposed to NCHW.

    Raises:
        Exception: if ``op.name`` is not in ``kwargs``, if the value has no
            ``shape`` attribute, or if its shape does not match the parameters.
    """
    if op.name not in kwargs:
        raise Exception("%s: missing input" % expr)
    imgs = kwargs[op.name]
    if not hasattr(imgs, "shape"):
        raise Exception("%s: Input must be ndarray: %s" % (expr, imgs))
    shape = imgs.shape
    params = op.parameters
    # Short-circuit on the rank check: indexing shape[1..3] of a lower-rank
    # array would raise IndexError instead of the size error below.
    if (len(shape) != 4
            or shape[1] != params["height"]
            or shape[2] != params["width"]
            or shape[3] != params["in_channels"]):
        raise Exception("%s: Invalid input size: %s" % (expr, imgs.shape))
    # NHWC => NCHW
    return imgs.astype(float).transpose(0, 3, 1, 2)


def Const(expr, op, args, **kwargs):
    """Return the constant stored in ``op.parameters["value"]`` unchanged."""
    parameters = op.parameters
    return parameters["value"]


def Neg(expr, op, args, **kwargs):
    """Return the arithmetic negation of the first entry of ``args``."""
    operand = args[0]
    return -operand


# Tensor addition.
def Add(expr, op, args, **kwargs):
    """Return ``args[0] + args[1]`` for two scalars or two same-shaped arrays.

    An operand counts as an array when it has a ``shape`` attribute.

    Raises:
        Exception: if exactly one operand is an array, or if both are arrays
            with different shapes.
    """
    lhs, rhs = args[0], args[1]
    lhs_is_array = hasattr(lhs, "shape")
    rhs_is_array = hasattr(rhs, "shape")

    if lhs_is_array != rhs_is_array:
        raise Exception("%s: cannot mix scalar and ndarray" % expr)
    if lhs_is_array and lhs.shape != rhs.shape:
        raise Exception("%s: size mismatch: %s+%s" % (expr, lhs.shape, rhs.shape))
    return lhs + rhs

# Tensor subtraction.
def Sub(expr, op, args, **kwargs):
    """Return ``args[0] - args[1]`` for two scalars or two same-shaped arrays.

    An operand counts as an array when it has a ``shape`` attribute.

    Raises:
        Exception: if exactly one operand is an array, or if both are arrays
            with different shapes.
    """
    lhs, rhs = args[0], args[1]
    lhs_is_array = hasattr(lhs, "shape")
    rhs_is_array = hasattr(rhs, "shape")

    if lhs_is_array != rhs_is_array:
        raise Exception("%s: cannot mix scalar and ndarray" % expr)
    if lhs_is_array and lhs.shape != rhs.shape:
        raise Exception("%s: size mismatch: %s-%s" % (expr, lhs.shape, rhs.shape))
    return lhs - rhs

# Tensor multiplication.
def Mul(expr, op, args, **kwargs):
    """Multiply ``args[0]`` by ``args[1]``.

    If either operand lacks a ``shape`` attribute the result is the plain
    ``*`` product. Two arrays are multiplied with ``np.matmul`` and must both
    be 2-D with matching inner dimensions.

    Raises:
        Exception: if both operands are arrays and either is not 2-D, or the
            inner dimensions differ.
    """
    lhs, rhs = args[0], args[1]
    both_arrays = hasattr(lhs, "shape") and hasattr(rhs, "shape")
    if not both_arrays:
        return lhs * rhs

    if not (len(lhs.shape) == 2 and len(rhs.shape) == 2):
        raise Exception("%s: matmul only: %s*%s" % (expr, lhs.shape, rhs.shape))
    if lhs.shape[1] != rhs.shape[0]:
        raise Exception("%s: size mismatch: %s*%s" % (expr, lhs.shape, rhs.shape))
    return np.matmul(lhs, rhs)

# Flatten.
def Flatten(expr, op, args, **kwargs):
    """Reshape ``args[0]`` to 2-D, keeping axis 0 and merging all other axes.

    Raises:
        Exception: if ``args[0]`` has no ``shape`` attribute.
    """
    x = args[0]
    if not hasattr(x, "shape"):
        # Report the offending value itself; the message previously referenced
        # an undefined name and failed with NameError instead.
        raise Exception("%s: ndarray only: %s" % (expr, x))
    return x.reshape((x.shape[0], -1))

# ReLU
def ReLU(expr, op, args, **kwargs):
    """Return ``max(x, 0)`` element-wise for ``x = args[0]``.

    Uses ``np.maximum`` rather than ``x * (x > 0)`` so that ``-inf`` maps to
    0 (the product form yields NaN for ``-inf * 0``) and negative inputs give
    ``0.0`` rather than ``-0.0``. NaN inputs still propagate as NaN.

    An input with a ``shape`` attribute yields a NumPy result; any other
    input yields a plain Python scalar, so downstream operations that tell
    scalars from arrays by the ``shape`` attribute keep seeing a scalar.
    """
    x = args[0]
    rectified = np.maximum(x, 0)
    if hasattr(x, "shape"):
        return rectified
    # np.maximum returns a NumPy scalar for Python scalars; unwrap it.
    return rectified.item()

# Linear (fully connected) layer.
def Linear(expr, op, args, **kwargs):
    """Return ``x @ weight.T + bias`` for the 2-D input ``x = args[0]``.

    ``op.parameters`` must provide ``weight`` with shape
    ``(out_features, in_features)``, ``bias`` with shape ``(out_features,)``,
    and the integers ``in_features`` and ``out_features``. ``x`` must have
    shape ``(n, in_features)``; the result has shape ``(n, out_features)``.

    Raises:
        Exception: if ``x``, ``weight`` or ``bias`` has no ``shape``
            attribute, if ``weight`` or ``bias`` is missing, or if any shape
            disagrees with ``in_features``/``out_features``.
    """
    x = args[0]
    if not hasattr(x, "shape"):
        raise Exception("%s: ndarray only: %s" % (expr, x))
    params = op.parameters
    if "weight" not in params or "bias" not in params:
        raise Exception("%s: missing weight or bias" % expr)
    weight = params["weight"]
    bias = params["bias"]
    if not hasattr(weight, "shape") or not hasattr(bias, "shape"):
        raise Exception("%s: ndarray only for weight or bias" % expr)
    in_features = params["in_features"]
    out_features = params["out_features"]
    # Short-circuit on the rank check: x.shape[1] on a lower-rank array would
    # raise IndexError instead of the size error below.
    if (len(x.shape) != 2
            or x.shape[1] != in_features
            or weight.shape != (out_features, in_features)
            or bias.shape != (out_features,)):
        raise Exception("%s: size mismatch: %s*%s+%s" % (expr, weight.shape, x.shape, bias.shape))
    return np.einsum("ni,oi->no", x, weight) + bias.reshape((1, out_features))

# 2-D max pooling.
def MaxPool2d(expr, op, args, **kwargs):
    """Max-pool the last two axes of the 4-D input ``x = args[0]``.

    Only non-overlapping windows are supported: ``op.parameters["kernel_size"]``
    must equal ``op.parameters["stride"]``, and axes 2 and 3 of ``x`` must be
    divisible by the stride. Each window is reduced with ``np.nanmax``, so
    NaN entries are ignored unless a whole window is NaN.

    Raises:
        Exception: if ``x`` has no ``shape`` attribute, if kernel size and
            stride differ, or if ``x`` is not 4-D or not divisible by stride.
    """
    x = args[0]
    if not hasattr(x, "shape"):
        raise Exception("%s: ndarray only: %s" % (expr, x))
    kernel_size = op.parameters["kernel_size"]
    stride = op.parameters["stride"]
    if kernel_size != stride:
        raise Exception("%s: kernel_size != stride" % expr)
    # Short-circuit on the rank check: x.shape[2]/[3] on a lower-rank array
    # would raise IndexError instead of the size error below.
    if (len(x.shape) != 4
            or x.shape[2] % stride != 0
            or x.shape[3] % stride != 0):
        raise Exception("%s: size mismatch: %s" % (expr, x.shape))
    # Split each spatial axis into (blocks, stride) and reduce over the two
    # within-block axes.
    new_shape = (x.shape[0], x.shape[1], x.shape[2] // stride, stride, x.shape[3] // stride, stride)
    return np.nanmax(x.reshape(new_shape), axis=(3, 5))


def Conv2d(expr, op, args, **kwargs):
    """Apply a 2-D cross-correlation with bias to the 4-D input ``x = args[0]``.

    ``op.parameters`` must provide ``weight`` with shape
    ``(out_channels, in_channels, kernel_size, kernel_size)``, ``bias`` with
    shape ``(out_channels,)``, and ``in_channels``, ``out_channels``,
    ``kernel_size`` and ``padding``. Axis 1 of ``x`` must equal
    ``in_channels``. Axes 2 and 3 are zero-padded by ``padding`` on each side,
    then the kernel is slid one element at a time over them.

    Raises:
        Exception: if ``x`` has no ``shape`` attribute, if ``weight`` or
            ``bias`` is missing, or if any shape disagrees with the parameters.
    """
    x = args[0]
    if not hasattr(x, "shape"):
        raise Exception("%s: ndarray only: %s" % (expr, x))
    params = op.parameters
    if "weight" not in params or "bias" not in params:
        raise Exception("%s: missing weight or bias" % expr)
    weight = params["weight"]
    bias = params["bias"]
    in_channels = params["in_channels"]
    out_channels = params["out_channels"]
    kernel_size = params["kernel_size"]
    padding = params["padding"]
    # Short-circuit on the rank check: x.shape[1] on a lower-rank array would
    # raise IndexError instead of the size error below.
    if (len(x.shape) != 4
            or x.shape[1] != in_channels
            or weight.shape != (out_channels, in_channels, kernel_size, kernel_size)
            or bias.shape != (out_channels,)):
        raise Exception("%s: size mismatch: %s" % (expr, x.shape))
    if padding != 0:
        padded = np.zeros((x.shape[0], x.shape[1], x.shape[2] + 2 * padding, x.shape[3] + 2 * padding))
        # Copy the input into the centre, leaving `padding` zeros on every
        # side of both spatial axes.
        padded[:, :, padding:-padding, padding:-padding] = x
        x = padded
    # Build a read-only view holding every kernel-sized window:
    # windows[n, i, h, w, y, x] aliases the padded input at [n, i, h + y, w + x].
    window_shape = x.shape[:2] + (x.shape[2] + 1 - kernel_size, x.shape[3] + 1 - kernel_size, kernel_size, kernel_size)
    window_strides = x.strides + x.strides[2:]
    windows = np.lib.stride_tricks.as_strided(x, shape=window_shape, strides=window_strides, writeable=False)
    return np.einsum("nihwyx,oiyx->nohw", windows, weight) + bias.reshape((1, out_channels, 1, 1))



def Show(expr, op, args, **kwargs):
    """Draw slices of the 4-D array ``x = args[0]`` on a grid and return ``x``.

    The grid has one row per index along axis 0 (at most 8 rows) and one
    column per index along axis 1; cell ``(row, c)`` shows ``x[row, c, :, :]``
    with ``imshow`` and is titled ``"row/c"``.
    NOTE(review): presumably ``x`` is laid out (batch, channel, height,
    width) -- confirm against callers.

    Raises:
        Exception: if ``x`` has no ``shape`` attribute or is not 4-D.
    """
    x = args[0]
    if not hasattr(x, "shape") or len(x.shape) != 4:
        raise Exception("%s: 4D ndarray only: %s" % (expr, x))
    rows = min(x.shape[0], 8)
    cols = x.shape[1]
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid; by default
    # subplots returns a bare Axes object there, which has no `.flat`.
    fig, axes = plot.subplots(rows, cols, squeeze=False)
    fig.set_size_inches(16, 1.5 * rows)
    for i, axe in enumerate(axes.flat):
        axe.set_xticks([])
        axe.set_yticks([])
        row, c = divmod(i, cols)
        axe.imshow(x[row, c, :, :])
        axe.set_title("%d/%d" % (row, c))
    return x





# Evaluator: runs the tensor operations of a program and returns the result.
class Evalulation:
    """Callable that evaluates a list of call_tensor.Expr objects in order."""

    # program is a list of call_tensor.Expr objects.
    def __init__(self, program):
        self.program = program

    def __call__(self, **kwargs):
        """Evaluate every expression and return the value of the last one.

        ``kwargs`` is forwarded unchanged to every operation; the input
        operations read their operands from it.

        Raises:
            Exception: if an expression's ``op.op_type`` does not name a
                callable defined at module level.
        """
        values = {}
        # Walk the program in order: each expression reads the stored results
        # of its inputs from `values` and writes its own result back under
        # its id.
        for expr in self.program:
            args = [values[ex.id] for ex in expr.inputs]
            # NOTE: dispatch goes through globals(), so any module-level
            # callable can be invoked by name; op_type must come from a
            # trusted program. Non-callable globals (imported modules,
            # constants) are reported as not implemented.
            handler = globals().get(expr.op.op_type)
            if not callable(handler):
                raise Exception("%s: not implemented" % expr)
            values[expr.id] = handler(expr, expr.op, args, **kwargs)
        return values[self.program[-1].id]
        
  
# Collects call_tensor.Expr computation expressions by appending them to program.
class Builder:
    """Accumulates expressions and wraps them in an evaluator on request."""

    def __init__(self):
        # Expressions in the order they were appended.
        self.program = list()

    def append(self, expr):
        """Add ``expr`` to the end of the recorded program."""
        recorded = self.program
        recorded.append(expr)

    def build(self):
        """Return an Evalulation that shares this builder's program list."""
        evaluator = Evalulation(self.program)
        return evaluator
